package controllers

import (
	"fmt"
	"github.com/hanc00l/nemo_go/pkg/db"
	"github.com/hanc00l/nemo_go/pkg/logging"
	"github.com/hanc00l/nemo_go/pkg/task/pocscan"
)

// VulController handles the vulnerability pages and endpoints: list page and
// list data, single-record detail, deletion, and xray/nuclei PoC file lookups.
// Request/response helpers it calls on itself (ServeJSON, FailedStatus, GetInt,
// ...) are presumably promoted from the embedded BaseController.
type VulController struct {
	BaseController
}

// vulRequestParam holds the form parameters accepted by ListAction.
// The embedded DatableRequestParam supplies the paging fields (Draw, Start,
// Length) read by the list query; the remaining fields are optional filters
// that getSearchMap only applies when they are non-empty / positive.
type vulRequestParam struct {
	DatableRequestParam
	Source    string `form:"vul_source"`
	Target    string `form:"vul_target"`
	PocFile   string `form:"vul_poc_file"`
	DateDelta int    `form:"date_delta"` // only used as a filter when > 0
}

// VulnerabilityData is one row of the list response built by
// getVulnerabilityListData and appended to DataTableResponseData.Data.
type VulnerabilityData struct {
	Id          int    `json:"id"`
	Index       int    `json:"index"` // 1-based row number: req.Start + i + 1
	Target      string `json:"target"`
	Url         string `json:"url"`
	PocFile     string `json:"poc_file"`
	Source      string `json:"source"`
	CreateTime  string `json:"create_datetime"` // formatted via FormatDateTime
	UpdateTime  string `json:"update_datetime"` // formatted via FormatDateTime
	WorkspaceId int    `json:"workspace"`
}

// VulnerabilityInfo is the detail view of a single vulnerability produced by
// getVulnerabilityInfo. InfoAction either serves it as JSON (no json tags, so
// the keys are the Go field names) or passes it to the template as "vul_info".
type VulnerabilityInfo struct {
	Id         int
	Target     string
	Url        string
	PocFile    string
	Source     string
	Extra      string
	CreateTime string // formatted via FormatDateTime
	UpdateTime string // formatted via FormatDateTime
	Workspace  string // workspace id rendered as a decimal string
}

// IndexAction renders the vulnerability list page inside the base layout.
// The table data itself is fetched separately through ListAction.
func (c *VulController) IndexAction() {
	c.TplName = "vulnerability-list.html"
	c.Layout = "base.html"
}

// ListAction 漏洞列表的数据
// ListAction serves one page of the vulnerability list as JSON.
//
// Form parsing errors are logged but do not abort the request: it continues
// with the (possibly partially populated) parameters after the paging
// defaults from validateRequestParam have been applied.
func (c *VulController) ListAction() {
	defer c.ServeJSON()

	var req vulRequestParam
	if err := c.ParseForm(&req); err != nil {
		logging.RuntimeLog.Error(err)
		logging.CLILog.Error(err)
	}
	c.validateRequestParam(&req)

	c.Data["json"] = c.getVulnerabilityListData(req)
}

// InfoAction 显示一个漏洞的详情
// InfoAction shows the details of the vulnerability identified by the "id"
// parameter. For server-API callers the details are served as JSON; otherwise
// the detail page is rendered. If "id" cannot be read as an int, the error is
// logged and an empty VulnerabilityInfo is used.
func (c *VulController) InfoAction() {
	var vulInfo VulnerabilityInfo
	if vulId, err := c.GetInt("id"); err == nil {
		vulInfo = getVulnerabilityInfo(vulId)
	} else {
		logging.RuntimeLog.Error(err)
		logging.CLILog.Error(err)
	}

	if !c.IsServerAPI {
		c.Data["vul_info"] = vulInfo
		c.Layout = "base.html"
		c.TplName = "vulnerability-info.html"
		return
	}
	c.Data["json"] = vulInfo
	c.ServeJSON()
}

// DeleteAction 删除一个记录
// DeleteAction deletes the vulnerability identified by the "id" parameter.
// Only SuperAdmin and Admin roles may do this; any other caller, or an "id"
// that cannot be read as an int, gets a failed status. The JSON response is
// always served on return.
func (c *VulController) DeleteAction() {
	defer c.ServeJSON()

	allowed := c.CheckMultiAccessRequest([]RequestRole{SuperAdmin, Admin}, false)
	if !allowed {
		c.FailedStatus("当前用户权限不允许！")
		return
	}

	vulId, err := c.GetInt("id")
	if err != nil {
		logging.RuntimeLog.Error(err)
		logging.CLILog.Error(err)
		c.FailedStatus(err.Error())
		return
	}

	record := db.Vulnerability{Id: vulId}
	c.MakeStatusResponse(record.Delete())
}

// LoadXrayPocFileAction 获取xray的pocfile列表
// LoadXrayPocFileAction returns the xray PoC file list as JSON.
// With type=custom it returns LoadPocFile(); any other value (default
// "default") returns LoadDefaultPocFile().
func (c *VulController) LoadXrayPocFileAction() {
	defer c.ServeJSON()

	xray := pocscan.NewXray(pocscan.Config{})
	switch c.GetString("type", "default") {
	case "custom":
		c.Data["json"] = xray.LoadPocFile()
	default:
		c.Data["json"] = xray.LoadDefaultPocFile()
	}
}

// LoadNucleiPocFileAction 获取Nuclei的pocfile列表
// LoadNucleiPocFileAction returns the nuclei PoC file list as JSON.
func (c *VulController) LoadNucleiPocFileAction() {
	nuclei := pocscan.NewNuclei(pocscan.Config{})
	pocFiles := nuclei.LoadPocFile()

	c.Data["json"] = pocFiles
	c.ServeJSON()
}

// validateRequestParam 校验请求的参数
// validateRequestParam normalizes the paging parameters in place: a negative
// offset is clamped to 0 and a non-positive page size falls back to 50.
// Running it on an already-valid req leaves req unchanged.
func (c *VulController) validateRequestParam(req *vulRequestParam) {
	const defaultPageLength = 50

	if req.Start < 0 {
		req.Start = 0
	}
	if req.Length <= 0 {
		req.Length = defaultPageLength
	}
}

// getSearchMap 根据查询参数生成查询条件
// getSearchMap builds the filter conditions passed to db.Vulnerability.Gets.
// The current workspace is included when its id is positive, and the string
// filters only when they are non-empty. The date delta is included only when
// it is positive.
func (c *VulController) getSearchMap(req vulRequestParam) (searchMap map[string]interface{}) {
	searchMap = map[string]interface{}{}

	if wsId := c.GetCurrentWorkspace(); wsId > 0 {
		searchMap["workspace_id"] = wsId
	}

	stringFilters := []struct{ key, value string }{
		{"target", req.Target},
		{"poc_file", req.PocFile},
		{"source", req.Source},
	}
	for _, f := range stringFilters {
		if f.value != "" {
			searchMap[f.key] = f.value
		}
	}

	if req.DateDelta > 0 {
		searchMap["date_delta"] = req.DateDelta
	}
	return searchMap
}

// getVulnerabilityListData 获取列显示的数据
// getVulnerabilityListData fetches one page of vulnerabilities matching req and
// converts it into the DataTables response structure.
//
// The page number passed to db.Vulnerability.Gets is req.Start/req.Length + 1,
// with req.Length as the page size. Each row's Index is req.Start + i + 1. Both
// RecordsTotal and RecordsFiltered are set to the total reported by Gets.
// Data is never nil, so it encodes as [] instead of null.
func (c *VulController) getVulnerabilityListData(req vulRequestParam) (resp DataTableResponseData) {
	// Normalize paging here too: req.Length is used as a divisor below and a
	// zero value would panic with an integer divide-by-zero. The helper is
	// idempotent, so callers that already validated req see no change; req is
	// a local copy, so the caller's value is untouched either way.
	c.validateRequestParam(&req)

	vul := db.Vulnerability{}
	searchMap := c.getSearchMap(req)
	startPage := req.Start/req.Length + 1
	results, total := vul.Gets(searchMap, startPage, req.Length)

	// Sized to the page and non-nil even when there are no results.
	resp.Data = make([]interface{}, 0, len(results))
	for i, vulRow := range results {
		resp.Data = append(resp.Data, VulnerabilityData{
			Id:          vulRow.Id,
			Index:       req.Start + i + 1,
			Target:      vulRow.Target,
			Url:         vulRow.Url,
			PocFile:     vulRow.PocFile,
			Source:      vulRow.Source,
			CreateTime:  FormatDateTime(vulRow.CreateDatetime),
			UpdateTime:  FormatDateTime(vulRow.UpdateDatetime),
			WorkspaceId: vulRow.WorkspaceId,
		})
	}
	resp.Draw = req.Draw
	resp.RecordsTotal = total
	resp.RecordsFiltered = total
	return resp
}

// getVulnerabilityInfo 获取一个漏洞的详情
// getVulnerabilityInfo loads the vulnerability with the given id and flattens
// it into a VulnerabilityInfo. If Get() reports false (presumably no matching
// record), the zero-value VulnerabilityInfo is returned.
func getVulnerabilityInfo(vulId int) (r VulnerabilityInfo) {
	record := db.Vulnerability{Id: vulId}
	if record.Get() {
		r = VulnerabilityInfo{
			Id:         vulId,
			Target:     record.Target,
			Url:        record.Url,
			PocFile:    record.PocFile,
			Source:     record.Source,
			Extra:      record.Extra,
			CreateTime: FormatDateTime(record.CreateDatetime),
			UpdateTime: FormatDateTime(record.UpdateDatetime),
			Workspace:  fmt.Sprintf("%d", record.WorkspaceId),
		}
	}
	return r
}
